{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import os\n",
    "import scipy.io as spio\n",
    "from matplotlib import pyplot as plt\n",
    "from imageio import imread"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Note: If you only have one of the two datasets (it does not matter which one), just run all of the notebook's cells and it will still work."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Base paths for the original Pascal VOC training images and their annotations.\n",
    "base_dataset_dir_voc = '<path-to-voc-2012>/PascalVoc2012/train/VOC2012'\n",
    "images_folder_name_voc = 'JPEGImages/'\n",
    "annotations_folder_name_voc = 'SegmentationClass_1D/'\n",
    "\n",
    "# Join the base directory with each sub-folder name.\n",
    "images_dir_voc, annotations_dir_voc = (\n",
    "    os.path.join(base_dataset_dir_voc, folder_name)\n",
    "    for folder_name in (images_folder_name_voc, annotations_folder_name_voc)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Base paths for the augmented Pascal VOC images and their annotations.\n",
    "# Download page: http://home.bharathh.info/pubs/codes/SBD/download.html\n",
    "base_dataset_dir_aug_voc = '<pascal/augmented/VOC/images/path>/benchmark_RELEASE/dataset'\n",
    "images_folder_name_aug_voc = 'img/'\n",
    "annotations_folder_name_aug_voc = 'cls/'\n",
    "\n",
    "# Join the base directory with each sub-folder name.\n",
    "images_dir_aug_voc, annotations_dir_aug_voc = (\n",
    "    os.path.join(base_dataset_dir_aug_voc, folder_name)\n",
    "    for folder_name in (images_folder_name_aug_voc, annotations_folder_name_aug_voc)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_files_list(base_dataset_dir, images_folder_name, annotations_folder_name, filename):\n",
    "    \"\"\"Read the image names listed in a text file, one name per line.\n",
    "\n",
    "    Args:\n",
    "        base_dataset_dir: Unused; kept so existing callers keep working.\n",
    "        images_folder_name: Unused; kept so existing callers keep working.\n",
    "        annotations_folder_name: Unused; kept so existing callers keep working.\n",
    "        filename: Path of the text file that lists one image name per line.\n",
    "\n",
    "    Returns:\n",
    "        A list with the non-empty lines of the file, stripped of surrounding\n",
    "        whitespace (including the trailing newline).\n",
    "    \"\"\"\n",
    "    # The context manager closes the file handle even if reading fails.\n",
    "    with open(filename, 'r') as names_file:\n",
    "        stripped_lines = (line.strip() for line in names_file)\n",
    "        # Skip blank lines so they are neither counted nor treated as image names.\n",
    "        return [name for name in stripped_lines if name]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the image names listed in custom_train.txt.\n",
    "images_filename_list = get_files_list(\n",
    "    base_dataset_dir=base_dataset_dir_aug_voc,\n",
    "    images_folder_name=images_folder_name_aug_voc,\n",
    "    annotations_folder_name=annotations_folder_name_aug_voc,\n",
    "    filename='custom_train.txt',\n",
    ")\n",
    "print('Total number of training images:', len(images_filename_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shuffle the file names in place and hold out a fraction of them for validation.\n",
    "# A dedicated, seeded generator makes the split the same after every kernel\n",
    "# restart (given the same input list) without touching NumPy's global RNG state.\n",
    "SPLIT_SEED = 42\n",
    "VALIDATION_FRACTION = 0.10\n",
    "\n",
    "np.random.RandomState(SPLIT_SEED).shuffle(images_filename_list)\n",
    "\n",
    "# Compute the split point once so both slices are guaranteed to agree.\n",
    "num_val_images = int(VALIDATION_FRACTION * len(images_filename_list))\n",
    "val_images_filename_list = images_filename_list[:num_val_images]\n",
    "train_images_filename_list = images_filename_list[num_val_images:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Report how many image names ended up in each split.\n",
    "for split_label, split_names in (('train set size:', train_images_filename_list),\n",
    "                                 ('val set size:', val_images_filename_list)):\n",
    "    print(split_label, len(split_names))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output folder and file names for the generated TFRecord files.\n",
    "TRAIN_DATASET_DIR = './tfrecords/'\n",
    "# makedirs(exist_ok=True) creates the folder only when it is missing, without a\n",
    "# separate existence check that could go stale before the directory is created.\n",
    "os.makedirs(TRAIN_DATASET_DIR, exist_ok=True)\n",
    "\n",
    "TRAIN_FILE = 'train.tfrecords'\n",
    "VALIDATION_FILE = 'validation.tfrecords'\n",
    "\n",
    "# One writer per output file; create_tfrecord_dataset closes the writer it is given.\n",
    "train_writer = tf.python_io.TFRecordWriter(os.path.join(TRAIN_DATASET_DIR, TRAIN_FILE))\n",
    "val_writer = tf.python_io.TFRecordWriter(os.path.join(TRAIN_DATASET_DIR, VALIDATION_FILE))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _bytes_feature(value):\n",
    "    \"\"\"Wrap one bytes value in a tf.train.Feature holding a BytesList.\"\"\"\n",
    "    bytes_list = tf.train.BytesList(value=[value])\n",
    "    return tf.train.Feature(bytes_list=bytes_list)\n",
    "\n",
    "def _int64_feature(value):\n",
    "    \"\"\"Wrap one integer value in a tf.train.Feature holding an Int64List.\"\"\"\n",
    "    int64_list = tf.train.Int64List(value=[value])\n",
    "    return tf.train.Feature(int64_list=int64_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_annotation_from_mat_file(annotations_dir, image_name):\n",
    "    \"\"\"Load the segmentation array stored in an image's .mat annotation file.\n",
    "\n",
    "    Args:\n",
    "        annotations_dir: Folder that contains the .mat annotation files.\n",
    "        image_name: Image name without extension; surrounding whitespace is stripped.\n",
    "\n",
    "    Returns:\n",
    "        The value found at ['GTcls']['Segmentation'][0][0] of the loaded file.\n",
    "    \"\"\"\n",
    "    mat_filename = '{}.mat'.format(image_name.strip())\n",
    "    mat_contents = spio.loadmat(os.path.join(annotations_dir, mat_filename))\n",
    "    ground_truth = mat_contents['GTcls']\n",
    "    return ground_truth['Segmentation'][0][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_tfrecord_dataset(filename_list, writer):\n",
    "    \"\"\"Serialize every image/annotation pair named in filename_list into writer.\n",
    "\n",
    "    Each image and annotation is looked up first in the augmented VOC folders and\n",
    "    then in the original Pascal VOC folders. Names whose image or annotation is\n",
    "    found in neither location are reported and skipped.\n",
    "\n",
    "    Args:\n",
    "        filename_list: Iterable of image names without extension; surrounding\n",
    "            whitespace (such as a trailing newline) is stripped.\n",
    "        writer: TFRecord writer that receives the serialized examples. It is\n",
    "            closed before this function returns, even when an error is raised.\n",
    "    \"\"\"\n",
    "    read_imgs_counter = 0\n",
    "    try:\n",
    "        for image_name in filename_list:\n",
    "            image_name = image_name.strip()\n",
    "\n",
    "            image_np = _read_image_with_fallback(image_name)\n",
    "            if image_np is None:\n",
    "                print('File:', image_name, 'not found.')\n",
    "                continue\n",
    "\n",
    "            annotation_np = _read_annotation_with_fallback(image_name)\n",
    "            if annotation_np is None:\n",
    "                print('File:', image_name, 'not found.')\n",
    "                continue\n",
    "\n",
    "            read_imgs_counter += 1\n",
    "            image_h, image_w = image_np.shape[:2]\n",
    "\n",
    "            # tobytes() returns the same raw buffer as the deprecated tostring() alias.\n",
    "            example = tf.train.Example(features=tf.train.Features(feature={\n",
    "                'height': _int64_feature(image_h),\n",
    "                'width': _int64_feature(image_w),\n",
    "                'image_raw': _bytes_feature(image_np.tobytes()),\n",
    "                'annotation_raw': _bytes_feature(annotation_np.tobytes())}))\n",
    "\n",
    "            writer.write(example.SerializeToString())\n",
    "\n",
    "        print('End of TfRecord. Total of image written:', read_imgs_counter)\n",
    "    finally:\n",
    "        # Always release the output file, including when reading or writing fails.\n",
    "        writer.close()\n",
    "\n",
    "\n",
    "def _read_image_with_fallback(image_name):\n",
    "    \"\"\"Return the .jpg image from the augmented VOC folder, else the original VOC folder, else None.\"\"\"\n",
    "    for images_dir in (images_dir_aug_voc, images_dir_voc):\n",
    "        try:\n",
    "            return imread(os.path.join(images_dir, image_name + '.jpg'))\n",
    "        except FileNotFoundError:\n",
    "            continue\n",
    "    return None\n",
    "\n",
    "\n",
    "def _read_annotation_with_fallback(image_name):\n",
    "    \"\"\"Return the annotation from the augmented .mat file, else the original VOC .png, else None.\"\"\"\n",
    "    try:\n",
    "        return read_annotation_from_mat_file(annotations_dir_aug_voc, image_name)\n",
    "    except FileNotFoundError:\n",
    "        pass\n",
    "    try:\n",
    "        return imread(os.path.join(annotations_dir_voc, image_name + '.png'))\n",
    "    except FileNotFoundError:\n",
    "        return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the training split; create_tfrecord_dataset closes train_writer when it finishes.\n",
    "create_tfrecord_dataset(filename_list=train_images_filename_list, writer=train_writer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the validation split; create_tfrecord_dataset closes val_writer when it finishes.\n",
    "create_tfrecord_dataset(filename_list=val_images_filename_list, writer=val_writer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
